from typing import Any, Optional, NamedTuple, Iterable, Callable
import jax.numpy as jnp
import jax
import haiku as hk
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import math
import numpy as np
import optax
import functools
from tqdm import tqdm
import tensorflow_probability.substrates.jax as tfp

from regression import make_hetstat_mlp, make_sac_mlp, train_model

# Short module aliases. NOTE(review): neither alias nor N_ITERS is referenced
# by the code in this file — confirm they are used elsewhere before removing.
tfd = tfp.distributions
hk_init = hk.initializers

# Iteration budget (presumably consumed by training code outside this file).
N_ITERS = 1_000


def _noise_std(x):
    """Return the standard deviation of the synthetic heteroscedastic noise at ``x``."""
    return np.abs(0.1 * np.sin(x))


def _se_kernel(x, y, lengthscale_sq=0.01):
    """Evaluate a squared-exponential kernel elementwise on broadcast inputs.

    Pass ``x[:, None]`` and ``y[None, :]`` to get a Gram matrix, or two equal
    1-D arrays to get the kernel's diagonal.
    """
    return np.exp(-0.5 * (x - y) ** 2 / lengthscale_sq)


def _desired_policy(query, s, a_clean, jitter=1e-3, data_range=(-6.0, 6.0)):
    """Return the mean and standard deviation of the "desired" policy.

    The desired policy is a GP posterior (squared-exponential kernel,
    conditioned on the clean targets ``a_clean`` at states ``s``). Its
    marginal variance is increased by the true noise variance inside the
    open interval ``data_range``.

    Args:
        query: 1-D array of states at which to evaluate the policy.
        s: 1-D array of training states.
        a_clean: noise-free targets at ``s``.
        jitter: diagonal regulariser added to the training Gram matrix.
        data_range: ``(low, high)`` bounds of the region where noise is added.

    Returns:
        ``(mean, std)``: two 1-D float arrays, each the same length as ``query``.
    """
    query = np.asarray(query, dtype=float)
    s = np.asarray(s, dtype=float)
    gram = _se_kernel(s[:, None], s[None, :]) + jitter * np.eye(s.shape[0])
    cross = _se_kernel(query[:, None], s[None, :])
    mean = cross @ np.linalg.solve(gram, a_clean)
    # Only the marginal variances are needed. Take the diagonal of
    # cross @ gram^-1 @ cross.T row by row instead of materialising a
    # len(query) x len(query) covariance (10_000 x 10_000 for the reward grid).
    weights = np.linalg.solve(gram, cross.T)
    var = _se_kernel(query, query) - np.einsum("ij,ji->i", cross, weights)
    low, high = data_range
    in_range = (query > low) & (query < high)
    # _noise_std is a standard deviation, so it must be squared before being
    # added to a variance.
    var = var + in_range * _noise_std(query) ** 2
    # Clamp round-off negatives so sqrt cannot produce NaN.
    return mean, np.sqrt(np.maximum(var, 0.0))


def _coherent_reward(log_prob, actions, nm):
    """Return the ``(nm, nm)`` coherent-reward grid, clipped below at -10.

    The reward is the policy log-density minus a standard-normal log-density
    of the actions.
    """
    reward = log_prob - jax.scipy.stats.norm.logpdf(actions, loc=0, scale=1)
    return jnp.clip(reward.reshape((nm, nm)), -10.0)


def _set_face_edgecolor(contour_set):
    """Give filled-contour patches edges matching their faces.

    Presumably this is done to avoid visible seams between contour levels in
    the saved PDF.
    """
    if hasattr(contour_set, "set_edgecolor"):
        # Matplotlib >= 3.8: the ContourSet is itself a Collection. The
        # `.collections` attribute is deprecated there and removed in 3.10.
        contour_set.set_edgecolor("face")
    else:
        for collection in contour_set.collections:
            collection.set_edgecolor("face")


def _plot_regression(ax, x, mean, std):
    """Plot ``mean`` with a shaded +/- one-``std`` band on ``ax``."""
    upper = mean + std
    lower = mean - std
    ax.plot(x, mean, "b")
    ax.fill_between(x, upper, lower, where=upper >= lower, color="b", alpha=0.3)


def _plot_reward(ax, s_grid, a_grid, reward):
    """Draw ``reward`` as a 100-level filled contour on ``ax``; return the ContourSet."""
    contour_set = ax.contourf(s_grid, a_grid, reward, levels=100)
    _set_face_edgecolor(contour_set)
    return contour_set


def main():
    """Train both regression models and save ``reward_refinement.pdf``.

    Column 0 shows the desired policy and reward. Columns 1 and 2 show the
    heteroscedastic MLP and the stationary heteroscedastic MLP. Rows 1 and 2
    hold the initial regression and reward; rows 3 and 4 hold the refined ones.
    """
    plt.rc("text", usetex=True)
    plt.rc("font", family="serif", size=9)

    np.random.seed(1)

    # Row 0 is a thin, axis-less strip that only carries the column titles.
    fig, axs = plt.subplots(
        5,
        3,
        figsize=(7.5, 4 * 1.3),
        sharex=True,
        sharey="row",
        gridspec_kw={"height_ratios": [0.02, 1, 1, 1, 1]},
    )

    def function(x):
        return 0.5 * np.sin(x - 0.2) - 0.3 * np.cos(3 * x + 0.5)

    def noisy_function(x):
        noise = _noise_std(x) * np.random.randn(*x.shape)
        return function(x) + noise

    n = 150
    s = np.linspace(-6, 6, n)
    a_clean = function(s)
    a_noisy = noisy_function(s)

    x = np.linspace(-20, 20, 1000)

    # (state, action) grid on which the coherent rewards are evaluated.
    nm = 100
    sm = jnp.linspace(-20, 20, nm)
    am = jnp.linspace(-1, 1, nm)
    Sm, Am = jnp.meshgrid(sm, am)
    sr = Sm.ravel()
    ar = Am.ravel()

    # The ideal desired policy combines a stationary prior with perfect
    # conditional density estimation. To model it, the true stochastic
    # process is combined with a squared-exponential covariance function.
    mu, std = _desired_policy(x, s, a_clean)
    for row in (1, 3):
        _plot_regression(axs[row, 0], x, mu, std)

    mur, stdr = _desired_policy(sr, s, a_clean)
    desired_reward = _coherent_reward(
        jax.scipy.stats.norm.logpdf(ar, loc=mur, scale=stdr), ar, nm
    )
    for row in (2, 4):
        _plot_reward(axs[row, 0], Sm, Am, desired_reward)

    def plot_policy(policy, reg_ax, reward_ax):
        """Plot ``policy``'s regression on ``reg_ax`` and its reward on ``reward_ax``."""
        dist = policy(x[:, None])
        y = dist.mode().squeeze()
        y_std = jnp.sqrt(dist.variance()).squeeze()
        _plot_regression(reg_ax, x, y, y_std)
        log_prob = policy(sr[:, None]).log_prob(ar[:, None])
        return _plot_reward(reward_ax, Sm, Am, _coherent_reward(log_prob, ar, nm))

    # Column 1: MLP. Column 2: HETSTAT MLP. Each model is first fit with
    # train_model's defaults (the "faithful" fit), then refined from those
    # parameters with faithful=False.
    models = (
        (1, make_sac_mlp, [256, 256]),
        (2, make_hetstat_mlp, [256, 256, 12, 256]),
    )
    for col, make_model, arch in models:
        network = hk.without_apply_rng(hk.transform(make_model(1, arch)))
        policy, params = train_model(network, s[:, None], a_noisy[:, None])
        plot_policy(policy, axs[1, col], axs[2, col])

        policy, params = train_model(
            network,
            s[:, None],
            a_noisy[:, None],
            refine=True,
            faithful=False,
            initial_param=params,
        )
        # The last value (axs[4, 2]) is what the colorbar below is built from.
        cnt = plot_policy(policy, axs[3, col], axs[4, col])

    for ax in axs[0, :]:
        ax.axis("off")
    axs[0, 0].set_title("Desired")
    axs[0, 1].set_title("Heteroscedastic MLP")
    axs[0, 2].set_title("Stationary Heteroscedastic MLP")

    row_titles = {
        1: "Initial regression",
        2: "Initial coherent reward",
        3: "Refined regression",
        4: "Refined coherent reward",
    }
    for row, title in row_titles.items():
        for ax in axs[row, 1:]:
            ax.set_title(title)

    for ax in axs[1:, 0]:
        ax.set_ylabel("$a$")

    for row in (1, 3):
        for ax in axs[row, :]:
            ax.plot(s, a_noisy, "k.", markersize=1)

    for ax in axs.flatten():
        ax.set_xlim(-20, 20)
        ax.set_xticklabels([])
        ax.set_xticks([])
        ax.set_yticklabels([])
        ax.set_yticks([])
    for ax in axs[-1, :]:
        ax.set_xlabel("$s$")

    fig.tight_layout()
    # NOTE(review): every contourf above autoscales its own 100 levels, so this
    # single colorbar only reflects the scale of the last reward panel.
    plt.colorbar(cnt, ax=axs)
    fig.savefig("reward_refinement.pdf", bbox_inches="tight")


if __name__ == "__main__":
    main()
    plt.show()
